import os, sys

# Put this file's own directory on sys.path so that sibling modules
# (presumably `copy_paste`, imported below) can be imported no matter what
# the current working directory is.
file_path = os.path.abspath(__file__)
sys.path.append(os.path.abspath(os.path.dirname(file_path)))
import cv2
from torchvision.datasets import CocoDetection
from copy_paste import copy_paste_class

# Minimum total number of visible keypoints (visibility flag > 0), summed over
# all of an image's annotations, that `has_valid_annotation` requires before it
# keeps an image whose annotations carry "keypoints".
min_keypoints_per_image = 10


def _count_visible_keypoints(anno):
    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)


def _has_only_empty_bbox(anno):
    return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)


def has_valid_annotation(anno):
    """Decide whether an image's annotation list is worth keeping.

    The list is rejected when it is empty or when every bounding box is
    degenerate (see ``_has_only_empty_bbox``). If the first annotation has a
    ``"keypoints"`` field, the image is treated as a keypoint sample and must
    also contain at least ``min_keypoints_per_image`` visible keypoints in
    total. Otherwise any list passing the box check is accepted.
    """
    # Nothing annotated, or only near-zero-area boxes: nothing to learn from.
    if len(anno) == 0 or _has_only_empty_bbox(anno):
        return False
    # Plain detection data (no keypoints on the first annotation) is accepted.
    if "keypoints" not in anno[0]:
        return True
    # Keypoint data uses the stricter visible-keypoint threshold.
    return _count_visible_keypoints(anno) >= min_keypoints_per_image


@copy_paste_class
class CocoDetectionCP(CocoDetection):
    """COCO detection dataset producing samples for copy-paste augmentation.

    At construction, images whose annotations fail ``has_valid_annotation``
    are removed from ``self.ids``. ``load_example`` reads one image, turns
    each annotation into a mask plus a box entry, and returns the result of
    ``self.transforms(image=..., masks=..., bboxes=...)``. The
    ``copy_paste_class`` decorator lives in the ``copy_paste`` module (not
    shown here) and presumably builds ``__getitem__`` on top of
    ``load_example``; confirm there.
    """

    def __init__(self, root, annFile, transforms):
        """
        Args:
            root: directory that each image's ``file_name`` is joined onto.
            annFile: path to the COCO-format annotation file.
            transforms: joint transform, later called with keyword arguments
                ``image``, ``masks`` and ``bboxes`` in ``load_example``.
        """
        # Only the joint ``transforms`` is used; the separate image/target
        # transforms of CocoDetection are left as None.
        super().__init__(root, annFile, None, None, transforms)

        # Drop images without any usable detection annotation.
        self.ids = [
            img_id
            for img_id in self.ids
            if has_valid_annotation(
                self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id, iscrowd=None))
            )
        ]

    def load_example(self, index):
        """Load one image with its instance masks and boxes, then transform it.

        Args:
            index: used directly as the COCO image id, not as a position
                in ``self.ids``.

        Returns:
            The value returned by ``self.transforms(image=..., masks=...,
            bboxes=...)``.

        Raises:
            ValueError: if the loaded image record's id differs from
                ``index``, or if the image file cannot be decoded.
        """
        import numpy as np

        # NOTE(review): ``index`` is treated as an image id, whereas __init__
        # filters ``self.ids`` and CocoDetection's __len__ is len(self.ids).
        # These agree only if image ids are 0..len-1 with nothing filtered, or
        # if the caller already maps positions through ``self.ids``. Confirm
        # against copy_paste_class.
        img_info = self.coco.loadImgs(index)[0]
        if img_info['id'] != index:  # sanity-check that index matches image_id
            raise ValueError(
                "error index {} with image_id {}".format(index, img_info['id'])
            )
        img_id = index

        target = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))

        image_path = os.path.join(self.root, img_info['file_name'])
        # Read raw bytes and decode them. This is presumably used instead of
        # cv2.imread so that non-ASCII paths work.
        image = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), flags=1)
        if image is None:
            # imdecode returns None on unreadable data; fail with the path
            # instead of letting cvtColor raise an opaque error.
            raise ValueError("could not decode image {}".format(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Build one mask and one box entry per annotation, in the same order.
        # Each box entry is the annotation's COCO ``bbox`` ([x, y, w, h]),
        # followed by its category_id and the index of its mask in ``masks``.
        masks = []
        bboxes = []
        for ix, obj in enumerate(target):
            masks.append(self.coco.annToMask(obj))
            bboxes.append(obj['bbox'] + [obj['category_id'], ix])

        output = {
            'image': image,
            'masks': masks,
            'bboxes': bboxes,
        }
        return self.transforms(**output)
